UNSUPERVISED REPRESENTATION LEARNING WITH DEEP CONVOLUTIONAL GENERATIVE ADVERSARIAL NETWORKS - Alec Radford & Luke Metz https://arxiv.org/pdf/1511.06434.pdf
# Instalacija paketa
!pip install opencv-python
# paket za manipuliranje slika cv2
!pip install numpy
# paket koji sadrzi funkcije za manipuliranje nizova i ostalih objekata
!pip install matplotlib
# paket koji se koristi za prikaz grafova
!pip install torch
# PyTorch - open source paket za masinsko programiranje preko kojeg se kreira GAN model
!pip install torchvision
# Sadrži popularne baze podataka koji se koriste za testiranje
Requirement already satisfied: opencv-python in c:\users\user\appdata\roaming\python\python39\site-packages (4.6.0.66) Requirement already satisfied: numpy>=1.19.3 in c:\users\user\anaconda3\lib\site-packages (from opencv-python) (1.21.5) Requirement already satisfied: numpy in c:\users\user\anaconda3\lib\site-packages (1.21.5) Requirement already satisfied: matplotlib in c:\users\user\anaconda3\lib\site-packages (3.5.2) Requirement already satisfied: pyparsing>=2.2.1 in c:\users\user\anaconda3\lib\site-packages (from matplotlib) (3.0.9) Requirement already satisfied: fonttools>=4.22.0 in c:\users\user\anaconda3\lib\site-packages (from matplotlib) (4.25.0) Requirement already satisfied: numpy>=1.17 in c:\users\user\anaconda3\lib\site-packages (from matplotlib) (1.21.5) Requirement already satisfied: pillow>=6.2.0 in c:\users\user\anaconda3\lib\site-packages (from matplotlib) (9.2.0) Requirement already satisfied: python-dateutil>=2.7 in c:\users\user\anaconda3\lib\site-packages (from matplotlib) (2.8.2) Requirement already satisfied: packaging>=20.0 in c:\users\user\anaconda3\lib\site-packages (from matplotlib) (21.3) Requirement already satisfied: cycler>=0.10 in c:\users\user\anaconda3\lib\site-packages (from matplotlib) (0.11.0) Requirement already satisfied: kiwisolver>=1.0.1 in c:\users\user\anaconda3\lib\site-packages (from matplotlib) (1.4.2) Requirement already satisfied: six>=1.5 in c:\users\user\anaconda3\lib\site-packages (from python-dateutil>=2.7->matplotlib) (1.16.0) Requirement already satisfied: torch in c:\users\user\anaconda3\lib\site-packages (1.13.1) Requirement already satisfied: typing_extensions in c:\users\user\anaconda3\lib\site-packages (from torch) (4.3.0) Requirement already satisfied: torchvision in c:\users\user\anaconda3\lib\site-packages (0.14.1) Requirement already satisfied: typing_extensions in c:\users\user\anaconda3\lib\site-packages (from torchvision) (4.3.0) Requirement already satisfied: numpy in 
c:\users\user\anaconda3\lib\site-packages (from torchvision) (1.21.5) Requirement already satisfied: requests in c:\users\user\anaconda3\lib\site-packages (from torchvision) (2.28.1) Requirement already satisfied: torch==1.13.1 in c:\users\user\anaconda3\lib\site-packages (from torchvision) (1.13.1) Requirement already satisfied: pillow!=8.3.*,>=5.3.0 in c:\users\user\anaconda3\lib\site-packages (from torchvision) (9.2.0) Requirement already satisfied: urllib3<1.27,>=1.21.1 in c:\users\user\anaconda3\lib\site-packages (from requests->torchvision) (1.26.11) Requirement already satisfied: charset-normalizer<3,>=2 in c:\users\user\anaconda3\lib\site-packages (from requests->torchvision) (2.0.4) Requirement already satisfied: certifi>=2017.4.17 in c:\users\user\anaconda3\lib\site-packages (from requests->torchvision) (2022.9.14) Requirement already satisfied: idna<4,>=2.5 in c:\users\user\anaconda3\lib\site-packages (from requests->torchvision) (3.3) Requirement already satisfied: tqdm in c:\users\user\anaconda3\lib\site-packages (4.64.1) Requirement already satisfied: colorama in c:\users\user\anaconda3\lib\site-packages (from tqdm) (0.4.5)
# Manipulacija slika
import os
import matplotlib.pyplot as plt
import numpy as np
import cv2
# Kreiranje neuronski mreža
import torch
import random
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms as T
import torchvision.utils as vutils
import torchvision
from torch.autograd import Variable
# Prikaz rezultata
from tqdm import tqdm
import matplotlib.animation as animation
from IPython.display import HTML
%matplotlib inline
Treniranje mreže i generisanje slika vrši se koristeći 3 teme:
# Dataset inspection: count the files in every dataset folder and load one
# sample image to verify the expected shape (256x256, 3 channels).
# FIX: each directory was listed twice (once for the count, once for the
# first file); the listing is now done once and reused.
pokemon_data_directory = './datasets/pokemon/'
pokemon_files = os.listdir(pokemon_data_directory)  # list once, reuse below
print('Broj podataka u pokemon datasetu: ' + str(len(pokemon_files)))
pokemon_card = cv2.imread(pokemon_data_directory + pokemon_files[0])
print("Oblik slika iz skupa pokemoni: " + str(pokemon_card.shape) + ". Rezolucija: 256x256, RGB slika\n")
# plt.imshow(pokemon_card)
vijesti_data_directory = './datasets/vijesti/'
vijesti_files = os.listdir(vijesti_data_directory)
print('Broj podataka u vijesti datasetu: ' + str(len(vijesti_files)))
vijest_article = cv2.imread(vijesti_data_directory + vijesti_files[0])
print("Oblik slika iz skupa vijesti: " + str(vijest_article.shape) + ". Rezolucija: 256x256, RGB slika\n")
lzn_data_directory = './datasets/lud_zbunjen_normalan/'
lzn_files = os.listdir(lzn_data_directory)
print('Broj podataka u lzn datasetu: ' + str(len(lzn_files)))
lzn_image = cv2.imread(lzn_data_directory + lzn_files[0])
print("Oblik slika iz skupa lzn: " + str(lzn_image.shape) + ". Rezolucija: 256x256, RGB slika\n")
red_dress_directory = './datasets/red_dress/'
print('Broj podataka u haljini datasetu: ' + str(len(os.listdir(red_dress_directory))) + '\n')
simpsons_data_directory = './datasets/simpsons/'
print('Broj podataka u simpsons datasetu: ' + str(len(os.listdir(simpsons_data_directory))) + '\n')
# Select the dataset used for training below
training_dataset = simpsons_data_directory
Broj podataka u pokemon datasetu: 819 Oblik slika iz skupa pokemoni: (256, 256, 3). Rezolucija: 256x256, RGB slika Broj podataka u vijesti datasetu: 1355 Oblik slika iz skupa vijesti: (256, 256, 3). Rezolucija: 256x256, RGB slika Broj podataka u lzn datasetu: 860 Oblik slika iz skupa lzn: (256, 256, 3). Rezolucija: 256x256, RGB slika Broj podataka u haljini datasetu: 800 Broj podataka u simpsons datasetu: 9877
# Display a sample of a dataset
def showDatasetSampleGrid(data_directory, num_of_images, grid_size, title):
    """Show a grid_size x grid_size grid of random images from a dataset folder.

    Parameters:
        data_directory: directory path (with trailing '/') containing the images.
        num_of_images: upper bound (exclusive) on the file index to sample.
            BUGFIX: the original ignored this parameter and always sampled
            indices from range(800), regardless of the dataset's size.
        grid_size: grid side length; grid_size**2 images are shown.
        title: figure suptitle.
    """
    fig, axes = plt.subplots(grid_size, grid_size, figsize=(8, 8))
    # List the directory once instead of once per grid cell.
    files = os.listdir(data_directory)
    # Never sample past the end of the actual file list.
    limit = min(num_of_images, len(files))
    for ax in axes.flatten():
        rand_index = random.randrange(limit)
        dataset_image = files[rand_index]
        img = cv2.imread(data_directory + dataset_image, cv2.IMREAD_COLOR)
        # OpenCV loads BGR; convert so matplotlib shows true colors.
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        ax.imshow(img)
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_title(dataset_image.split('.')[0])
    plt.suptitle(title)
    plt.tight_layout()
    plt.show()
# Preview one dataset at a time; uncomment the line you want to inspect.
# showDatasetSampleGrid(pokemon_data_directory, 800, 5, 'Pokemon dataset')
# showDatasetSampleGrid(vijesti_data_directory, 1300, 5, 'Vijesti dataset')
showDatasetSampleGrid(lzn_data_directory, 800, 5, 'Lud zbunjen normalan dataset')
# showDatasetSampleGrid(red_dress_directory, 800, 5, 'Crvene haljine dataset')
# showDatasetSampleGrid(simpsons_data_directory, 8000, 5, 'Simpsons dataset')
Referenca: https://proceedings.neurips.cc/paper/2014/file/5ca3e9b122f61f8f06494c97b1afccf3-Paper.pdf
def initialize_weights_v2(model):
    """DCGAN weight init (Radford & Metz, 2015): conv weights ~ N(0, 0.02),
    BatchNorm weights ~ N(1, 0.02), all present biases set to 0.

    BUGFIX: the original initialized conv weights twice (both `if` branches
    matched), drew BatchNorm weights from N(0, 0.02) instead of N(1, 0.02),
    and crashed on conv layers created with bias=False (m.bias is None there).
    """
    for m in model.modules():
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
            nn.init.normal_(m.weight.data, 0.0, 0.02)
            if m.bias is not None:  # convs in this file use bias=False
                nn.init.constant_(m.bias.data, 0)
        elif isinstance(m, nn.BatchNorm2d):
            nn.init.normal_(m.weight.data, 1.0, 0.02)
            nn.init.constant_(m.bias.data, 0)
def initialize_weights(model):
    """DCGAN initialization (Radford & Metz, 2015): convolution weights are
    drawn from N(0, 0.02); BatchNorm scale from N(1, 0.02) with zero bias."""
    for module in model.modules():
        name = module.__class__.__name__
        if 'Conv' in name:
            nn.init.normal_(module.weight, 0.0, 0.02)
        elif 'BatchNorm' in name:
            nn.init.normal_(module.weight, 1.0, 0.02)
            nn.init.constant_(module.bias, 0)
# Seed both RNGs so a run can be reproduced (the seed itself is random).
seed = random.randint(1, 10000)
random.seed(seed)
torch.manual_seed(seed)
<torch._C.Generator at 0x20d56ca1050>
in_channels (input features) i out_channels definišu veličinu težina
class Generator(nn.Module):
    """DCGAN generator: maps a latent vector (z_dim x 1 x 1) to an RGB image.

    resize=False builds the deep 256x256 architecture; resize=True builds
    the shallower 64x64 one. The final Tanh maps outputs into [-1, 1].
    """

    def __init__(self, z_dim, img_channels, features_g, resize=False):
        super(Generator, self).__init__()
        self.resize = resize
        if (self.resize == False):
            # 1x1 -> 4x4 -> 8x8 -> 16x16 -> 32x32 -> 64x64 -> 128x128 -> 256x256
            mults = [64, 32, 16, 8, 4, 2]
            blocks = [self.layer(z_dim, features_g * mults[0], 4, 1, 0, bn=True)]
            for prev, cur in zip(mults, mults[1:]):
                blocks.append(self.layer(features_g * prev, features_g * cur, 4, 2, 1, bn=True))
            blocks.append(nn.ConvTranspose2d(features_g * mults[-1], img_channels, 4, 2, 1))
            blocks.append(nn.Tanh())  # normalize to the real-image value range
            self.gen = nn.Sequential(*blocks)
        else:
            # 1x1 -> 4x4 -> 8x8 -> 16x16 -> 32x32 -> 64x64
            mults = [8, 4, 2, 1]
            blocks = [self.layer(z_dim, features_g * mults[0], 4, 1, 0, bn=True)]
            for prev, cur in zip(mults, mults[1:]):
                blocks.append(self.layer(features_g * prev, features_g * cur, 4, 2, 1, bn=True))
            blocks.append(nn.ConvTranspose2d(features_g, img_channels, 4, 2, 1, bias=False))
            blocks.append(nn.Tanh())
            self.gen = nn.Sequential(*blocks)

    def layer(self, in_channels, out_channels, kernel_size, stride, padding, bn=True):
        """One upsampling block: transposed conv (+ optional BatchNorm) + ReLU."""
        modules = [nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding, bias=False)]
        if bn:
            modules.append(nn.BatchNorm2d(out_channels))
        modules.append(nn.ReLU(True))
        return nn.Sequential(*modules)

    def forward(self, input):
        return self.gen(input)
class Discriminator(nn.Module):
    """DCGAN discriminator: maps an image to a single 1x1 logit.

    resize=False expects 3x256x256 inputs, resize=True expects 3x64x64.
    There is no final Sigmoid, so pair this with BCEWithLogitsLoss.
    """

    def __init__(self, img_channels, features_d, resize=False):
        super(Discriminator, self).__init__()
        self.resize = resize
        if (self.resize == False):
            # 256x256 -> 128 -> 64 -> 32 -> 16 -> 8 -> 4 -> 1x1 logit
            mults = [1, 2, 4, 8, 16, 32]
            blocks = [self.layer(img_channels, features_d, 4, 2, 1, bn=False)]
            for prev, cur in zip(mults, mults[1:]):
                blocks.append(self.layer(features_d * prev, features_d * cur, 4, 2, 1, bn=True))
            blocks.append(nn.Conv2d(features_d * 32, 1, 4, 2, 0))
            self.disc = nn.Sequential(*blocks)
        else:
            # 64x64 -> 32 -> 16 -> 8 -> 4 -> 1x1 logit
            mults = [1, 2, 4, 8]
            blocks = [self.layer(img_channels, features_d, 4, 2, 1, bn=False)]
            for prev, cur in zip(mults, mults[1:]):
                blocks.append(self.layer(features_d * prev, features_d * cur, 4, 2, 1, bn=True))
            blocks.append(nn.Conv2d(features_d * 8, 1, 4, 2, 0, bias=False))
            self.disc = nn.Sequential(*blocks)

    def layer(self, in_channels, out_channels, kernel_size, stride, padding, bn=True):
        """One downsampling block: conv (+ optional BatchNorm) + LeakyReLU."""
        modules = [nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False)]
        if bn:
            modules.append(nn.BatchNorm2d(out_channels))
        modules.append(nn.LeakyReLU(0.1, inplace=True))
        return nn.Sequential(*modules)

    def forward(self, input):
        return self.disc(input)
# Generator smoke test: sample one 256x256 image from an untrained generator.
N, in_channels, H, W = 8, 3, 256, 256  # batch of 8 images, 3x256x256 each
noise_dim = 100  # latent (z) dimension
z = torch.randn((N, noise_dim, 1, 1))
generated_image = Generator(noise_dim, in_channels, 8)
image = generated_image(z)[0]
# BUGFIX: permute() is not in-place -- the original discarded its result and
# then used reshape(256,256,3), which scrambles the channel layout instead of
# moving the channel axis. Assign the CHW -> HWC permutation instead.
image = image.permute(1, 2, 0)
# Tanh output lies in [-1, 1]; map to [0, 255] before the uint8 cast
# (the original's plain *255 wrapped around for negative values).
plt.imshow(((image.detach().numpy() + 1.0) * 127.5).astype(np.uint8))
plt.axis('off')
(-0.5, 255.5, 255.5, -0.5)
def testGAN(resize=False):
    """Smoke-test the GAN pair: verify discriminator and generator output shapes.

    resize=False checks the 256x256 architecture, resize=True the 64x64 one.
    Raises AssertionError when a shape is wrong.
    """
    N, in_channels = 8, 3  # batch of 8 RGB images
    H = W = 64 if resize else 256
    noise_dim = 100  # latent dimension
    x = torch.randn((N, in_channels, H, W))
    z = torch.randn((N, noise_dim, 1, 1))
    disc = Discriminator(in_channels, 8, resize=resize)
    gen = Generator(noise_dim, in_channels, 8, resize=resize)
    d_shape = disc(x).shape
    print(f'Diskriminator oblik: {d_shape}')
    assert d_shape == (N, 1, 1, 1), "Diskriminator nije prošao test"
    g_shape = gen(z).shape
    print(f'Generator oblik: {g_shape}')
    assert g_shape == (N, in_channels, H, W), "Generator nije prošao test"
    print("Uspješno!")
testGAN()
Diskriminator oblik: torch.Size([8, 1, 1, 1]) Generator oblik: torch.Size([8, 3, 256, 256]) Uspješno!
# https://pytorch.org/tutorials/beginner/basics/data_tutorial.html
class LoadDataset(torch.utils.data.Dataset):
    """Image-folder dataset yielding (image tensor in [0, 1], dummy label 1).

    BUGFIXES vs. the original:
      * cv2.cvtColor returns a new array -- the original discarded the result,
        so images stayed in BGR channel order; it is now assigned back.
      * os.listdir was re-run on every __len__/__getitem__ call; the file
        list is now cached once (sorted, for a deterministic index order).
      * interpolation=2 (deprecated int argument) replaced with the
        InterpolationMode enum, removing the torchvision warning.
    """

    def __init__(self, img_path, resize):
        super(LoadDataset, self).__init__()
        self.img_path = img_path  # directory path, expected to end with '/'
        self.resize = resize  # when True, images are resized to 64x64
        self.files = sorted(os.listdir(img_path))  # cache the listing once
        print(f'Putanja: {self.img_path}')

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        pth = self.files[idx]
        img = cv2.imread(self.img_path + pth, cv2.IMREAD_COLOR)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # BGR -> RGB
        img = torch.tensor(img)
        img = img.permute(2, 0, 1)  # HWC -> CHW, as PyTorch expects
        if (self.resize):
            img = torchvision.transforms.functional.resize(
                img, (64, 64),
                interpolation=torchvision.transforms.InterpolationMode.BILINEAR)
        # Scale to [0, 1]; the label is a constant placeholder (all "real").
        return img/255.0, 1
# Training hyperparameters (lr and betas follow the DCGAN paper)
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
LEARNING_RATE = 0.0002
BETA = (0.5, 0.999)
BATCH_SIZE = 512 # tried 16; 256 took too long; 128 worked best
# IMAGE_SIZE = 256
IMG_CHANNELS = 3
Z_DIM = 100 # latent vector size
FEATURES_DISC = 64
FEATURES_GEN = 64
NUM_EPOCHS = 1000
print(DEVICE)
cuda
# Alternatives: lzn_data_directory, vijesti_data_directory, pokemon_data_directory
dataset = LoadDataset(img_path=training_dataset, resize=True)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
# Sanity check: print the shape and label of the first image of every batch
for test_images, test_labels in dataloader:
    sample_image = test_images[0]
    sample_label = test_labels[0]
    # imgplot = plt.imshow(sample_image.permute(1, 2, 0))
    # plt.show()
    print(sample_image.shape, sample_label)
Putanja: ./datasets/simpsons/
C:\Users\User\anaconda3\lib\site-packages\torchvision\transforms\functional.py:442: UserWarning: Argument 'interpolation' of type int is deprecated since 0.13 and will be removed in 0.15. Please use InterpolationMode enum. warnings.warn(
torch.Size([3, 64, 64]) tensor(1) torch.Size([3, 64, 64]) tensor(1) torch.Size([3, 64, 64]) tensor(1) torch.Size([3, 64, 64]) tensor(1) torch.Size([3, 64, 64]) tensor(1) torch.Size([3, 64, 64]) tensor(1) torch.Size([3, 64, 64]) tensor(1) torch.Size([3, 64, 64]) tensor(1) torch.Size([3, 64, 64]) tensor(1) torch.Size([3, 64, 64]) tensor(1) torch.Size([3, 64, 64]) tensor(1) torch.Size([3, 64, 64]) tensor(1) torch.Size([3, 64, 64]) tensor(1) torch.Size([3, 64, 64]) tensor(1) torch.Size([3, 64, 64]) tensor(1) torch.Size([3, 64, 64]) tensor(1) torch.Size([3, 64, 64]) tensor(1) torch.Size([3, 64, 64]) tensor(1) torch.Size([3, 64, 64]) tensor(1) torch.Size([3, 64, 64]) tensor(1)
# Build the 64x64 GAN pair and move it to the training device.
gen = Generator(Z_DIM, IMG_CHANNELS, FEATURES_GEN, resize=True).to(DEVICE)
disc = Discriminator(IMG_CHANNELS, FEATURES_DISC, resize=True).to(DEVICE)
# BUGFIX: netG/netD were only defined on CUDA machines, so the later
# optimizer and training cells crashed with NameError on CPU. Default to
# the bare models and wrap in DataParallel only when a GPU is available.
netG, netD = gen, disc
if DEVICE.type == 'cuda':
    netG = nn.DataParallel(gen, list(range(1)))
    netD = nn.DataParallel(disc, list(range(1)))
# DCGAN weight init. Initializing gen/disc also covers the DataParallel
# wrappers (they share the same underlying modules), so the original's
# extra netG.apply/netD.apply pass was redundant and has been dropped.
initialize_weights(gen)
initialize_weights(disc)
testGAN(resize=True)
Diskriminator oblik: torch.Size([8, 1, 1, 1]) Generator oblik: torch.Size([8, 3, 64, 64]) Uspješno!
# Adam optimizers with DCGAN-paper betas; BCEWithLogitsLoss because the
# discriminator outputs raw logits (its final Sigmoid is commented out).
opt_gen = optim.Adam(netG.parameters(), lr=LEARNING_RATE, betas=BETA)
opt_disc = optim.Adam(netD.parameters(), lr=LEARNING_RATE, betas=BETA)
criterion = nn.BCEWithLogitsLoss()
# criterion = nn.BCELoss()
# Fixed latent batch so progress snapshots are comparable across epochs
fixed_noise = torch.randn((64, Z_DIM, 1, 1)).to(DEVICE)
gen.train();
disc.train();
print("TRENIRANJE NA DATASET-U: ", training_dataset)
# GAN training optimization tips:
# https://medium.com/@utk.is.here/keep-calm-and-train-a-gan-pitfalls-and-tips-on-training-generative-adversarial-networks-edd529764aa9
# Buffers for loss curves and snapshot image grids
img_list = []
epoha_list = []
G_losses = []
D_losses = []
# Label smoothing technique - https://towardsdatascience.com/gan-ways-to-improve-gan-performance-acf37f9f59b
smoothing = 0.1
real_label = 1. - smoothing
fake_label = 0. + smoothing
# TODO: add n_critic updates as well (Wasserstein GAN)
# Show snapshot images during training (useful to spot stalls on long runs)
show_images_while_training = True
save_image_iteration = 10 # snapshot every 10 epochs
######## GAN training loop
for epoch in range(NUM_EPOCHS):
    for i, data in enumerate(dataloader, 0):
        # --- Discriminator update on a real batch ---
        # Reset gradients each iteration, otherwise they accumulate on the leaves
        netD.zero_grad()
        # Real images, their batch size, and (smoothed) "real" labels
        real_img = data[0].to(DEVICE)
        b_size = real_img.size(0)
        label = torch.full((b_size,), real_label, dtype=torch.float, device=DEVICE)
        # Discriminator forward pass on the real images
        output = netD(real_img).view(-1)
        # Loss against the real label -- D is expected to output ~1 here
        errD_real = criterion(output, label)
        # Backprop D's gradients for the real batch
        errD_real.backward()
        # --- Discriminator update on a fake batch ---
        # One latent vector per real image
        noise = torch.randn(b_size, Z_DIM, 1, 1, device=DEVICE)
        # Generate fakes, then reuse the label tensor filled with the fake label
        fake = netG(noise)
        label.fill_(fake_label)
        # detach() keeps the generator out of this backward pass
        output = netD(fake.detach()).view(-1)
        # D's loss on the fakes
        errD_fake = criterion(output, label)
        errD_fake.backward()
        # Total D loss = real part + fake part
        errD = errD_real + errD_fake
        # Apply the discriminator optimizer step
        opt_disc.step()
        # --- Generator update (after D has been updated) ---
        netG.zero_grad()
        label.fill_(real_label)  # the generator wants D to call its fakes "real"
        # Re-classify the fakes with the freshly updated discriminator
        output = netD(fake).view(-1)
        # Generator loss from D's classification
        errG = criterion(output, label)
        errG.backward()
        opt_gen.step()
        G_losses.append(errG.item())
        D_losses.append(errD.item())
    # Save a generator snapshot via torchvision.utils every few epochs
    if (epoch % save_image_iteration == 9) or ((epoch == NUM_EPOCHS-1) and (i == len(dataloader)-1)):
        with torch.no_grad():
            fake = netG(fixed_noise).detach().cpu()
            # detach/no_grad: no gradients needed here -- PyTorch would
            # otherwise keep the whole computation graph alive
            img_list.append(torch.flip(vutils.make_grid(fake, padding=2, normalize=True), [-1]))
            epoha_list.append(epoch+1)
        # Optionally show progress images during training
        if (show_images_while_training):
            image = torch.flip(vutils.make_grid(fake, padding=2, normalize=True).permute(1,2,0), [-1])
            plt.imshow(image)
            plt.show()
    print('Epoha [{:d}/{:d}] -> d_loss: {:6.4f} | g_loss: {:6.4f}'.format(
        epoch+1, NUM_EPOCHS, errD.item(), errG.item()))
print("Treniranje završeno.")
TRENIRANJE NA DATASET-U: ./datasets/simpsons/
C:\Users\User\anaconda3\lib\site-packages\torchvision\transforms\functional.py:442: UserWarning: Argument 'interpolation' of type int is deprecated since 0.13 and will be removed in 0.15. Please use InterpolationMode enum. warnings.warn(
Epoha [1/1000] -> d_loss: 0.7437 | g_loss: 1.2477 Epoha [2/1000] -> d_loss: 0.7466 | g_loss: 2.7567 Epoha [3/1000] -> d_loss: 0.9323 | g_loss: 0.9265 Epoha [4/1000] -> d_loss: 0.7425 | g_loss: 1.4641 Epoha [5/1000] -> d_loss: 0.7188 | g_loss: 1.6329 Epoha [6/1000] -> d_loss: 1.0182 | g_loss: 3.3703 Epoha [7/1000] -> d_loss: 1.1227 | g_loss: 0.9412 Epoha [8/1000] -> d_loss: 0.7978 | g_loss: 2.2107 Epoha [9/1000] -> d_loss: 1.3784 | g_loss: 3.9970
Epoha [10/1000] -> d_loss: 0.7825 | g_loss: 1.3030 Epoha [11/1000] -> d_loss: 0.7777 | g_loss: 2.7425 Epoha [12/1000] -> d_loss: 0.6999 | g_loss: 2.1127 Epoha [13/1000] -> d_loss: 0.7047 | g_loss: 1.6068 Epoha [14/1000] -> d_loss: 1.0943 | g_loss: 3.7395 Epoha [15/1000] -> d_loss: 1.0073 | g_loss: 3.3448 Epoha [16/1000] -> d_loss: 0.7172 | g_loss: 2.0708 Epoha [17/1000] -> d_loss: 0.6821 | g_loss: 2.1056 Epoha [18/1000] -> d_loss: 0.8286 | g_loss: 2.0734 Epoha [19/1000] -> d_loss: 0.8047 | g_loss: 1.3893
Epoha [20/1000] -> d_loss: 0.7147 | g_loss: 1.8753
--------------------------------------------------------------------------- KeyboardInterrupt Traceback (most recent call last) ~\AppData\Local\Temp\ipykernel_36428\1364346884.py in <module> 21 ######## Treniranje GAN modela 22 for epoch in range(NUM_EPOCHS): ---> 23 for i, data in enumerate(dataloader, 0): 24 # Prolaz kroz grupe slika 25 ~\anaconda3\lib\site-packages\torch\utils\data\dataloader.py in __next__(self) 626 # TODO(https://github.com/pytorch/pytorch/issues/76750) 627 self._reset() # type: ignore[call-arg] --> 628 data = self._next_data() 629 self._num_yielded += 1 630 if self._dataset_kind == _DatasetKind.Iterable and \ ~\anaconda3\lib\site-packages\torch\utils\data\dataloader.py in _next_data(self) 669 def _next_data(self): 670 index = self._next_index() # may raise StopIteration --> 671 data = self._dataset_fetcher.fetch(index) # may raise StopIteration 672 if self._pin_memory: 673 data = _utils.pin_memory.pin_memory(data, self._pin_memory_device) ~\anaconda3\lib\site-packages\torch\utils\data\_utils\fetch.py in fetch(self, possibly_batched_index) 56 data = self.dataset.__getitems__(possibly_batched_index) 57 else: ---> 58 data = [self.dataset[idx] for idx in possibly_batched_index] 59 else: 60 data = self.dataset[possibly_batched_index] ~\anaconda3\lib\site-packages\torch\utils\data\_utils\fetch.py in <listcomp>(.0) 56 data = self.dataset.__getitems__(possibly_batched_index) 57 else: ---> 58 data = [self.dataset[idx] for idx in possibly_batched_index] 59 else: 60 data = self.dataset[possibly_batched_index] ~\AppData\Local\Temp\ipykernel_36428\3590488943.py in __getitem__(self, idx) 11 12 def __getitem__(self, idx): ---> 13 pth = os.listdir(self.img_path)[idx] 14 img = cv2.imread(self.img_path + pth, cv2.IMREAD_COLOR) 15 cv2.cvtColor(img, cv2.COLOR_BGR2RGB) KeyboardInterrupt:
# Plot generator and discriminator losses over training iterations
plt.figure(figsize=(10,5))
plt.title("Gubici prilikom treniranja")
plt.plot(G_losses,label="Generator")
plt.plot(D_losses,label="Diskriminator")
plt.xlabel("Iteracije")
plt.ylabel("Gubitak")
plt.legend()
plt.show()
# Animate the saved snapshot grids across epochs
fig, ax = plt.subplots()
plt.axis("off")
container = []
# CHW -> HWC, and flip back the horizontal mirror applied when stored
generated_images = [[plt.imshow(torch.flip(np.transpose(i,(1,2,0)), [-1]), animated=True)] for i in img_list]
for i in range(len(epoha_list)):
    image_grid = generated_images[i][0]
    # Epoch number shown as a title above each frame
    title = ax.text(0.5,1.05,"Epoha {}".format(epoha_list[i]),
        size=plt.rcParams["axes.titlesize"],
        ha="center", transform=ax.transAxes, )
    container.append([image_grid, title])
ani = animation.ArtistAnimation(fig, container, interval=700, repeat_delay=700, blit=False)
HTML(ani.to_jshtml())
Model se spašava kroz funkciju torch.save gdje se definiše putanja. Zatim je moguće pokrenuti istrenirani generator i korististi za generisanje slika
Slike se generišu kroz generator koji vraća tensor niz koji predstavlja sliku. Kako bi se slika prikazala i sačuvala, potrebno je transponirati sliku korištenjem procesora (CPU).
# torch.save(gen.state_dict(), 'generator.pth')
# Sample a grid of images from the trained 64x64 generator.
N, in_channels, H, W = 8, 3, 64, 64 # batch of 8 images, 3x64x64 each
noise_dim = 100 # latent (z) dimension
z = torch.randn((N, noise_dim, 1, 1))
# NOTE(review): z stays on the CPU while netG may be on the GPU -- if this
# raises a device-mismatch error, move z with .to(DEVICE) first.
generated_image = np.transpose(vutils.make_grid(netG(z).detach().cpu()),(1,2,0))
print(generated_image.shape)
# plt.imshow((fake.detach().cpu().numpy() * 255).astype(np.uint8))
# plt.axis('off')
torch.Size([68, 530, 3])
Hipervarijable: https://ijeee.edu.iq/Papers/Vol18-Issue1/1570796090.pdf
GAN model adaptacija za CGAN: https://www.cs.toronto.edu/~lczhang/321/lec/gan_notes.html
U ovom CGAN modelu koristi se Linear sloj (y = x*AT + b) gdje je X -> ulazni podaci, A -> težina, b -> bias, generator kroz linear sloj povećava dimenziju.
A => dimenzija in_features X out_features
from torchvision import datasets
from torchvision.transforms import ToTensor
# FashionMNIST: 60k 28x28 grayscale images in 10 classes -- the class
# labels condition the CGAN below.
# NOTE: this rebinds training_dataset (previously a directory-path string).
training_dataset = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor()
)
data_loader = torch.utils.data.DataLoader(training_dataset, batch_size=32, shuffle=True)
class Generator(nn.Module):
    """Conditional GAN generator (MLP): latent vector + class label -> 28x28 image.

    The label is embedded into a 10-dim vector and concatenated with the
    100-dim latent code, giving the 110-dim input of the first Linear layer.
    """

    def __init__(self):
        super().__init__()
        # 10 classes, each embedded into a 10-dim vector
        self.label_emb = nn.Embedding(10, 10)
        widths = [110, 256, 512, 1024]
        layers = []
        for n_in, n_out in zip(widths, widths[1:]):
            layers.append(nn.Linear(n_in, n_out))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
        layers.append(nn.Linear(1024, 784))
        layers.append(nn.Tanh())  # squash into the [-1, 1] image range
        self.gen = nn.Sequential(*layers)

    def forward(self, z, labels):
        # Flatten the latent batch and join it with the label embedding
        flat_z = z.view(z.size(0), 100)
        joined = torch.cat([flat_z, self.label_emb(labels)], 1)
        out = self.gen(joined)
        return out.view(joined.size(0), 28, 28)
class Discriminator(nn.Module):
    """Conditional GAN discriminator (MLP): flattened 28x28 image + class
    label -> probability that the image is real (Sigmoid output).

    Dropout deliberately weakens the discriminator to keep the adversarial
    game balanced and reduce overfitting.
    """

    def __init__(self):
        super().__init__()
        # 10 classes, each embedded into a 10-dim vector
        self.label_emb = nn.Embedding(10, 10)
        widths = [794, 1024, 512, 256]
        layers = []
        for n_in, n_out in zip(widths, widths[1:]):
            layers.append(nn.Linear(n_in, n_out))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            layers.append(nn.Dropout(0.3))
        layers.append(nn.Linear(256, 1))
        layers.append(nn.Sigmoid())  # probability of "real"
        self.disc = nn.Sequential(*layers)

    def forward(self, x, labels):
        flat = x.view(x.size(0), 784)
        joined = torch.cat([flat, self.label_emb(labels)], 1)
        return self.disc(joined).squeeze()
# https://www.researchgate.net/figure/The-architecture-of-conditional-GANCGAN_fig3_350115869
generator = Generator().to(DEVICE)
discriminator = Discriminator().to(DEVICE)
# Plain BCELoss: this discriminator already ends in a Sigmoid
criterion = nn.BCELoss()
# LR = 0.0002 might be worth trying as well
opt_disc = optim.Adam(discriminator.parameters(), lr=LEARNING_RATE, betas=BETA)
opt_gen = optim.Adam(generator.parameters(), lr=LEARNING_RATE, betas=BETA)
print("TRENIRANJE NA DATASET-U: ", training_dataset)
BATCH_SIZE = 32
NUM_EPOCHS = 50 # 100 would be better
# CGAN training loop over (image, label) batches.
# FIXES vs. the original:
#   * label/noise tensors are sized by the actual batch (images.size(0))
#     instead of the BATCH_SIZE constant, so a final partial batch cannot
#     break the loss computation;
#   * the fakes are detached for the discriminator update, avoiding a
#     useless backward pass through the generator;
#   * the deprecated Variable wrapper and the unused `step` counter removed.
for epoch in range(NUM_EPOCHS):
    for i, (images, labels) in enumerate(data_loader):
        real_images = images.to(DEVICE)
        labels = labels.to(DEVICE)
        b_size = real_images.size(0)
        generator.train()
        opt_disc.zero_grad()
        # Discriminator on real images: target 1
        realD = discriminator(real_images, labels)
        errD_real = criterion(realD, torch.ones(b_size, device=DEVICE))
        # Discriminator on generated images with random class labels: target 0
        z = torch.randn(b_size, 100, device=DEVICE)
        fake_labels = torch.randint(0, 10, (b_size,), dtype=torch.long, device=DEVICE)
        fake = generator(z, fake_labels)
        # detach(): the generator is not part of the discriminator update
        fakeD = discriminator(fake.detach(), fake_labels)
        errD_fake = criterion(fakeD, torch.zeros(b_size, device=DEVICE))
        # Total D loss and optimizer step
        errD = errD_real + errD_fake
        errD.backward()
        opt_disc.step()
        # Generator update: try to fool the freshly trained discriminator
        opt_gen.zero_grad()
        z = torch.randn(b_size, 100, device=DEVICE)
        fake_labels = torch.randint(0, 10, (b_size,), dtype=torch.long, device=DEVICE)
        fake_images = generator(z, fake_labels)
        output = discriminator(fake_images, fake_labels)
        errG = criterion(output, torch.ones(b_size, device=DEVICE))
        errG.backward()
        opt_gen.step()
    print('Epoha [{:d}/{:d}] -> d_loss: {:6.4f} | g_loss: {:6.4f}'.format(
        epoch+1, NUM_EPOCHS, errD.item(), errG.item()))
print("Treniranje završeno.")
TRENIRANJE NA DATASET-U: Dataset FashionMNIST
Number of datapoints: 60000
Root location: data
Split: Train
StandardTransform
Transform: ToTensor()
Epoha [1/50] -> d_loss: 1.2395 | g_loss: 0.8052
Epoha [2/50] -> d_loss: 1.3016 | g_loss: 0.8841
Epoha [3/50] -> d_loss: 1.2923 | g_loss: 0.9236
Epoha [4/50] -> d_loss: 1.2733 | g_loss: 0.8232
Epoha [5/50] -> d_loss: 1.3053 | g_loss: 0.8601
Epoha [6/50] -> d_loss: 1.1837 | g_loss: 0.8448
Epoha [7/50] -> d_loss: 1.2879 | g_loss: 1.0121
Epoha [8/50] -> d_loss: 1.2161 | g_loss: 0.7655
Epoha [9/50] -> d_loss: 1.2055 | g_loss: 1.0355
Epoha [10/50] -> d_loss: 1.3106 | g_loss: 0.8585
Epoha [11/50] -> d_loss: 1.1102 | g_loss: 0.9037
Epoha [12/50] -> d_loss: 1.2440 | g_loss: 0.9412
Epoha [13/50] -> d_loss: 1.1783 | g_loss: 1.1122
Epoha [14/50] -> d_loss: 1.1419 | g_loss: 1.0241
Epoha [15/50] -> d_loss: 1.2759 | g_loss: 0.9626
Epoha [16/50] -> d_loss: 1.1628 | g_loss: 0.9596
Epoha [17/50] -> d_loss: 1.2131 | g_loss: 0.8438
Epoha [18/50] -> d_loss: 1.3995 | g_loss: 0.9768
Epoha [19/50] -> d_loss: 1.1178 | g_loss: 1.0337
Epoha [20/50] -> d_loss: 1.1342 | g_loss: 0.9816
Epoha [21/50] -> d_loss: 1.1998 | g_loss: 0.7885
Epoha [22/50] -> d_loss: 1.2256 | g_loss: 0.7497
Epoha [23/50] -> d_loss: 1.1240 | g_loss: 1.0188
Epoha [24/50] -> d_loss: 1.1099 | g_loss: 0.9433
Epoha [25/50] -> d_loss: 1.1521 | g_loss: 1.0552
Epoha [26/50] -> d_loss: 1.0232 | g_loss: 1.1746
Epoha [27/50] -> d_loss: 1.1246 | g_loss: 0.9809
Epoha [28/50] -> d_loss: 1.1667 | g_loss: 1.1981
Epoha [29/50] -> d_loss: 1.1132 | g_loss: 1.4857
Epoha [30/50] -> d_loss: 1.1716 | g_loss: 1.5477
Epoha [31/50] -> d_loss: 1.0883 | g_loss: 1.1025
Epoha [32/50] -> d_loss: 1.1271 | g_loss: 1.0805
Epoha [33/50] -> d_loss: 1.1692 | g_loss: 1.2308
Epoha [34/50] -> d_loss: 1.3571 | g_loss: 1.2528
Epoha [35/50] -> d_loss: 0.7786 | g_loss: 1.3195
Epoha [36/50] -> d_loss: 0.9004 | g_loss: 1.3807
Epoha [37/50] -> d_loss: 1.0127 | g_loss: 1.0535
Epoha [38/50] -> d_loss: 1.1810 | g_loss: 1.0611
Epoha [39/50] -> d_loss: 0.8524 | g_loss: 1.0565
Epoha [40/50] -> d_loss: 0.8887 | g_loss: 1.0766
Epoha [41/50] -> d_loss: 0.8922 | g_loss: 1.5439
Epoha [42/50] -> d_loss: 0.9416 | g_loss: 1.3468
Epoha [43/50] -> d_loss: 0.8473 | g_loss: 1.8123
Epoha [44/50] -> d_loss: 0.8508 | g_loss: 1.5014
Epoha [45/50] -> d_loss: 0.9337 | g_loss: 1.2128
Epoha [46/50] -> d_loss: 1.0708 | g_loss: 0.9405
Epoha [47/50] -> d_loss: 1.1823 | g_loss: 1.1478
Epoha [48/50] -> d_loss: 1.0394 | g_loss: 1.4508
Epoha [49/50] -> d_loss: 0.9898 | g_loss: 1.3145
Epoha [50/50] -> d_loss: 1.0058 | g_loss: 1.4612
Treniranje završeno.
def generate_image(generator, img_class):
    """Sample a single image of class img_class from the trained CGAN generator.

    Returns an HWC tensor with pixel values mapped from [-1, 1] into [0, 1].
    """
    noise = torch.randn(1, 100, device=DEVICE)
    cls = torch.full((1,), img_class, dtype=torch.long, device=DEVICE)
    sample = generator(noise, cls).data.cpu()
    sample = 0.5 * sample + 0.5  # tanh range [-1, 1] -> [0, 1]
    return np.transpose(vutils.make_grid(sample),(1,2,0))
# Generate and display one sample per FashionMNIST class (0-9)
for i in range(10):
    gen_img = generate_image(generator, i)
    plt.figure(figsize = (2,2))
    plt.axis('off')
    plt.imshow((gen_img.numpy() * 255).astype(np.uint8), aspect='auto')